(ns splendid.ml.nb
  (:use splendid.ml.text))

(defn train-class
  "Builds the training record for a single `class` from a collection of `texts`.

  Every text is run through `normalize-fn`, then `split-text-fn`, then
  `bagging-fn`. The per-text bags are summed key-wise into one class bag, and
  each key's count is divided by the class bag's grand total to give
  `:class-occurrences`.

  Keyword options:
    :normalize-fn  - defaults to `normalize-text`
    :split-text-fn - defaults to `split-text-into-words-by-whitespace`
    :bagging-fn    - defaults to `bag-of-words`

  Returns a map with keys:
    :class             - `class`, unchanged
    :bags              - vector of per-text bags (not read elsewhere in this ns)
    :class-total-bag   - key-wise sum of :bags
    :class-count       - sum of all values in :class-total-bag (not read elsewhere in this ns)
    :class-occurrences - key -> count / :class-count"
  [class texts & {:keys [normalize-fn split-text-fn bagging-fn]
                  :or {normalize-fn  normalize-text
                       split-text-fn split-text-into-words-by-whitespace
                       bagging-fn    bag-of-words}}]
  (let [text->bag     (comp bagging-fn split-text-fn normalize-fn)
        per-text-bags (mapv text->bag texts)
        merged        (apply merge-with + per-text-bags)
        grand-total   (apply + (vals merged))]
    {:class             class
     :bags              per-text-bags
     :class-total-bag   merged
     :class-count       grand-total
     :class-occurrences (into {}
                              (for [[token n] merged]
                                [token (/ n grand-total)]))}))

(defn train
  "Trains a naive Bayes model from `classes-texts-map`, a map of
  class -> collection of texts.

  Any keyword options given after the map (e.g. :normalize-fn,
  :split-text-fn, :bagging-fn) are forwarded unchanged to `train-class` for
  every class. With no options, `train-class`'s own defaults apply, exactly as
  before.

  Returns a map with keys:
    :classes           - the classes, in `classes-texts-map` key order
    :texts-per-class   - class -> number of texts (not read elsewhere in this ns)
    :class-probs       - class -> that class's fraction of all texts (the prior)
    :training-bags     - one `train-class` result per class
    :total-bag         - key-wise sum of every class's :class-total-bag
                         (not read elsewhere in this ns)
    :total-occurrences - sorted map of key -> count / grand total of :total-bag
                         (not read elsewhere in this ns)

  NOTE(review): if every class has zero texts, computing :class-probs divides
  0 by 0 and throws ArithmeticException."
  [classes-texts-map & {:as opts}]
  (let [texts-per-class   (zipmap (keys classes-texts-map)
                                  (map count (vals classes-texts-map)))
        total-texts-count (reduce + (vals texts-per-class))
        class-probs       (zipmap (keys texts-per-class)
                                  (map #(/ % total-texts-count) (vals texts-per-class)))
        ;; Flatten the options map back into k v k v ... so it can be applied
        ;; to train-class's keyword-argument destructuring; nil -> ().
        class-opts        (apply concat opts)
        training-bags     (map (fn [[class texts]]
                                 (apply train-class class texts class-opts))
                               classes-texts-map)
        total             (apply merge-with +
                                 (map :class-total-bag training-bags))
        ;; Named total-count rather than `count` so clojure.core/count is not shadowed.
        total-count       (reduce + (vals total))]
    {:classes           (keys classes-texts-map)
     :texts-per-class   texts-per-class
     :class-probs       class-probs
     :training-bags     training-bags
     :total-bag         total
     :total-occurrences (reduce-kv #(assoc %1 %2 (/ %3 total-count))
                                   (sorted-map) total)}))

(defn classify
  "Scores `text` against every class of a model `nb` (as returned by `train`).
  Returns a sorted map of class -> normalized score as a double. The scores
  sum to (approximately) 1.

  The text is passed through `normalize-fn` and then `split-text-fn`. A
  class's raw score is its prior from `:class-probs` multiplied by one factor
  per DISTINCT word of the text. That factor is the word's value in the
  class's `:class-occurrences`, or `unseen-prob` when the class has no entry
  for the word. Repeated words therefore contribute only once. Raw scores are
  then divided by their sum.

  Keyword options:
    :normalize-fn  - defaults to `normalize-text`; should match training
    :split-text-fn - defaults to `split-text-into-words-by-whitespace`;
                     should match training
    :unseen-prob   - factor for words a class never saw; defaults to 1/1000
    :debug?        - when truthy, prints per-class intermediate values and the
                     sum of raw scores to *out*; defaults to false

  NOTE(review): with the exact rational probabilities produced by `train`,
  raw scores are exact but may become very large ratios for long texts. If
  every raw score is exactly zero (e.g. :unseen-prob 0), the final division
  throws ArithmeticException."
  [nb text & {:keys [normalize-fn split-text-fn unseen-prob debug?]
              :or {normalize-fn  normalize-text
                   split-text-fn split-text-into-words-by-whitespace
                   unseen-prob   1/1000
                   debug?        false}}]
  (let [class-probs (:class-probs nb)
        word-set    (->> text normalize-fn split-text-fn set)
        ;; Every distinct word starts at the fallback probability; words the
        ;; class saw during training override it in `class-words` below.
        unseen      (zipmap word-set (repeat unseen-prob))
        score-class (fn [bag]
                      (let [c           (:class bag)
                            occurrences (:class-occurrences bag)
                            class-prob  (get class-probs c)
                            class-words (merge unseen (select-keys occurrences word-set))
                            in-class    (vals class-words)
                            raw-score   (apply * class-prob in-class)]
                        (when debug?
                          (prn c occurrences)
                          (prn class-words)
                          (prn in-class (apply * in-class))
                          (prn class-prob in-class raw-score)
                          (println))
                        [c raw-score]))
        raw-scores  (into {} (map score-class (:training-bags nb)))
        all-probs   (reduce + (vals raw-scores))]
    (when debug?
      (println "All probs:" all-probs))
    (into (sorted-map)
          (for [[c raw-score] raw-scores]
            [c (double (/ raw-score all-probs))]))))
